import numpy
import numba
from mdprop import Prop
def debug(*x):
    """Swallow diagnostic output.

    Called throughout this module with arbitrary positional arguments at
    points of interest; it deliberately discards them and returns None.
    Replace the body with ``print(x)`` to see the values while debugging.
    """
    del x  # intentionally unused
# Coefficients of the smoothing polynomial gamma(rho) used to split the 1/r
# kernel. Row k (k = order // 2, the "split order") holds the coefficients of
# rho**(2n) for n = 0..k, i.e. for rho < 1:
#     gamma(rho) = sum_n GCONS[k, n] * rho**(2n)
# (evaluated by Horner's rule in rho**2 by MSM's ``gamma`` closure).
# Rows 0 and 1 are never assigned and stay zero.
GCONS = numpy.zeros((7, 7))
GCONS[2, :3] = numpy.asarray([15.0/8., -5.0/4., 3.0/8., ])
GCONS[3, :4] = numpy.asarray([35.0/16., -35.0/16., 21.0/16., -5.0/16., ])
GCONS[4, :5] = numpy.asarray([315.0/128., -105.0/32., 189.0/64., -45.0/32., 35.0/128., ])
GCONS[5, :6] = numpy.asarray([693.0/256., -1155.0/256., 693.0/128., -495.0/128., 385.0/256., -63.0/256., ])
GCONS[6, :7] = numpy.asarray([3003.0/1024., -3003.0/512., 9009.0/1024., -2145.0/256., 5005.0/1024., -819.0/512., 231.0/1024., ])

# Coefficients of d(gamma)/d(rho) for rho < 1, in the same split-order rows:
#     dgamma(rho) = rho * sum_n DGCONS[k, n] * rho**(2n),  n = 0..k-1
# Each entry equals 2*(n+1) * GCONS[k, n+1], the term-by-term derivative of
# the GCONS polynomial.
DGCONS = numpy.zeros((7, 6))
DGCONS[2, :2] = numpy.asarray([-5.0/2., 3.0/2., ])
DGCONS[3, :3] = numpy.asarray([-35.0/8., 21.0/4., -15.0/8., ])
DGCONS[4, :4] = numpy.asarray([-105.0/16., 189.0/16., -135.0/16., 35.0/16., ])
DGCONS[5, :5] = numpy.asarray([-1155.0/128., 693.0/32., -1485.0/64., 385.0/32., -315.0/128., ])
DGCONS[6, :6] = numpy.asarray([-3003.0/256., 9009.0/256., -6435.0/128., 5005.0/128., -4095.0/256., 693.0/256., ])

# Piecewise polynomials of the 1-D grid interpolation basis phi(d), keyed by
# interpolation order. PHICOEF[order] has order // 2 rows; row k is the
# polynomial in |d| used on k <= |d| < k+1. Each row has ``order`` entries,
# listed from the |d|**(order-1) coefficient down to the constant term (the
# layout expected by the Horner loop in MSM's ``compute_phi1``).
PHICOEF = {
    4: numpy.asarray([
        numpy.asarray([1.5, -2.5, 0., 1.])/1,
        numpy.asarray([-0.5, 2.5, -4., 2.])/1,
    ]),
    6: numpy.asarray([
        numpy.asarray([-5., 13., 5., -25., 0., 12.])/12.,
        numpy.asarray([5., -39., 105., -105., 10., 24.])/24.,
        numpy.asarray([-1., 13., -65., 155., -174., 72.])/24.,
    ]),
    8: numpy.asarray([
        numpy.asarray([7., -25., -35., 161., 28., -280., 0., 144.])/144.,
        numpy.asarray([-7., 75., -273., 315., 252., -630., 28., 240.])/240.,
        numpy.asarray([7., -125., 889., -3185., 5908., -4970., 756., 720.])/720.,
        numpy.asarray([-1., 25., -259., 1435., -4564., 8260., -7776., 2880.])/720.,
    ]),
    10: numpy.asarray([
        numpy.asarray([-9., 41., 126., -654., -441., 3129., 324., -5396., 0., 2880.])/2880.,
        numpy.asarray([3., -41., 180., -138., -945., 1911., 690., -3172., 72., 1440.])/1440.,
        numpy.asarray([-9., 205., -1872., 8610., -19845., 15645., 18342., -34540., 3384., 10080.])/10080.,
        numpy.asarray([9., -287., 3870., -28686., 126945., -339423., 523080., -397684., 71856., 40320.])/40320.,
        numpy.asarray([-1., 41., -726., 7266., -45129., 179529., -454544., 700204., -588240., 201600.])/40320.,
    ]),
}
# Derivatives of the PHICOEF pieces with respect to |d|: same keys and row
# layout, ``order - 1`` coefficients per row (highest power first). The sign
# of d is applied separately by MSM's ``compute_dphi1``.
DPHICOEF = {
    4: numpy.asarray([
        numpy.asarray([9, -10, 0])/2.,
        numpy.asarray([-3, 10, -8])/2.,
    ]),
    6: numpy.asarray([
        numpy.asarray([-25, 52, 15, -50, 0])/12.,
        numpy.asarray([25, -156, 315, -210, 10])/24.,
        numpy.asarray([-5, 52, -195, 310, -174])/24.,
    ]),
    8: numpy.asarray([
        numpy.asarray([49, -150, -175, 644, 84, -560, 0])/144.,
        numpy.asarray([-49, 450, -1365, 1260, 756, -1260, 28])/240.,
        numpy.asarray([49, -750, 4445, -12740, 17724, -9940, 756])/720.,
        numpy.asarray([-7, 150, -1295, 5740, -13692, 16520, -7776])/720.,
    ]),
    10: numpy.asarray([
        numpy.asarray([-81, 328, 882, -3924, -2205, 12516, 972, -10792, 0])/2880.,
        numpy.asarray([27, -328, 1260, -828, -4725, 7644, 2070, -6344, 72])/1440.,
        numpy.asarray([-81, 1640, -13104, 51660, -99225, 62580, 55026, -69080, 3384])/10080.,
        numpy.asarray([81, -2296, 27090, -172116, 634725, -1357692, 1569240, -795368, 71856])/40320.,
        numpy.asarray([-9, 328, -5082, 43596, -225645, 718116, -1363632, 1400408, -588240])/40320.,
    ]),
}

class MSM:
    """Multilevel summation method (MSM) for Coulomb interactions in a periodic cell.

    The 1/r kernel is split at ``cutoff`` into two parts:

    - a short-range part, returned as a pair function by ``get_pair``;
    - a smooth long-range part, handled on a hierarchy of grids by ``compute``.

    The long-range pass spreads the charges onto the finest grid and restricts
    them level by level to coarser grids. Each level is convolved with its own
    direct-sum stencil. The potentials are then prolongated back to the finest
    grid and interpolated to per-particle forces.

    Grid arrays are indexed ``[z, y, x]``. Per-dimension vectors such as
    ``nmsm[n]``, ``delinv[n]`` or ``nlo_out[n]`` are ordered ``(x, y, z)``.
    """
    Props = [
        Prop('levels', numpy.asarray, None, True),
        Prop('order', int, 4)
        ]

    def __init__(self, config, elec_const):
        """Build the grid hierarchy, the compiled kernel functions and the stencils.

        Parameters
        ----------
        config
            Configuration object. Attributes read:

            - ``basis``: per-dimension cell lengths; the cell spans
              ``-basis/2 .. basis/2``.
            - ``nonbprops.cutoff``, ``nonbprops.skin``, ``nonbprops.scale14``.
            - ``nonbprops.coulprops.levels``: array of per-dimension level counts.
            - ``nonbprops.coulprops.order``: interpolation order; must be a key
              of ``PHICOEF``, i.e. 4, 6, 8 or 10.
        elec_const
            Coulomb prefactor applied to every ``q1 * q2 / r`` term.
        """
        msm_levels = config.nonbprops.coulprops.levels
        order = config.nonbprops.coulprops.order
        ucell_lo = config.basis * -0.5
        ucell_hi = -ucell_lo
        self.scale14 = config.nonbprops.scale14
        cutoff = config.nonbprops.cutoff
        skin = config.nonbprops.skin

        # --- grid hierarchy ---------------------------------------------
        levels = msm_levels.max()
        ucell_len = ucell_hi - ucell_lo
        nmsm = numpy.zeros((levels, 3), dtype=int)
        delinv = numpy.zeros((levels, 3))
        for i in range(levels):
            # Every level halves the previous one. A dimension with fewer
            # levels bottoms out at a single grid point. The exponent is
            # clamped because numpy rejects negative integer powers of
            # integers.
            nmsm[i] = 2 ** numpy.maximum(msm_levels - i, 0)
            delinv[i] = nmsm[i] / ucell_len
        debug(nmsm)

        # Half-width, in grid points, of the direct-sum stencil.
        ndirect = (2 * cutoff * delinv[0]).astype(int)
        self.ndirect = ndirect

        # The "in" range 0 .. nmsm-1 covers the unit cell. The "out" range
        # adds ghost points wide enough for both the interpolation stencil
        # and the direct-sum stencil. Level 0 additionally reaches half the
        # skin beyond each face.
        ghost = numpy.maximum(order, ndirect)
        nlo_out = numpy.zeros((levels, 3), dtype=int)
        nhi_out = numpy.zeros((levels, 3), dtype=int)
        nlo_in = numpy.zeros((levels, 3), dtype=int)
        nhi_in = numpy.zeros((levels, 3), dtype=int)
        for i in range(levels):
            if i == 0:
                coord_lo = ucell_lo - skin / 2
                coord_hi = ucell_hi + skin / 2
            else:
                coord_lo = ucell_lo
                coord_hi = ucell_hi
            debug(coord_lo, coord_hi)
            nlo = numpy.floor((coord_lo - ucell_lo) * delinv[i]).astype(int)
            nhi = numpy.floor((coord_hi - ucell_lo) * delinv[i]).astype(int)
            nlo_out[i] = nlo - ghost
            nhi_out[i] = nhi + ghost
            nlo_in[i] = 0
            nhi_in[i] = nmsm[i] - 1
        self.elec_const = elec_const
        self.max_levels = levels
        self.nmsm = nmsm
        self.delinv = delinv
        self.nlo_out = nlo_out
        self.nhi_out = nhi_out
        self.nlo_in = nlo_in
        self.nhi_in = nhi_in
        self.ucell_lo = ucell_lo
        self.ucell_hi = ucell_hi
        self.order = order
        self.cutoff = cutoff
        self.skin = skin

        # --- kernel functions (compiled with the tables as constants) ---
        split_order = order // 2
        gcons = GCONS[split_order]
        dgcons = DGCONS[split_order]
        phicoef = PHICOEF[order]
        dphicoef = DPHICOEF[order]
        debug(type(phicoef))
        cutinv = 1. / cutoff
        cut2inv = 1. / (cutoff * cutoff)

        @numba.jit(nopython=True)
        def gamma(rho):
            """Smoothed 1/rho: GCONS polynomial in rho**2 below 1, exactly 1/rho above."""
            if rho < 1.:
                rho2 = rho * rho
                g = gcons[split_order]
                for n in range(split_order - 1, -1, -1):
                    g = g * rho2 + gcons[n]
                return g
            else:
                return 1. / rho

        @numba.jit(nopython=True)
        def dgamma(rho):
            """Derivative of ``gamma`` with respect to rho."""
            if rho < 1.:
                rho2 = rho * rho
                dg = dgcons[split_order - 1]
                for n in range(split_order - 2, -1, -1):
                    dg = dg * rho2 + dgcons[n]
                return dg * rho
            else:
                # d(1/rho)/drho. The polynomial branch reaches -1 at rho == 1,
                # so this sign keeps dgamma continuous.
                return -1. / (rho * rho)

        # Index of the last polynomial piece of phi, and the half-width of
        # phi's support (phi vanishes for |d| >= order / 2).
        half_order_floor = numpy.asarray([(order - 1) // 2], dtype=numpy.int64)
        support = order / 2.
        debug(phicoef)

        @numba.jit(nopython=True)
        def compute_phi1(d):
            """Evaluate the 1-D basis phi elementwise on a 1-D array of offsets."""
            absd = numpy.abs(d)
            iabsd = numpy.floor(absd).astype(numpy.int64)
            coefs = phicoef[numpy.minimum(iabsd, half_order_floor)]
            ret = coefs[:, 0]
            for i in range(1, order):
                ret = ret * absd + coefs[:, i]
            # Compact support: never extrapolate the outermost piece.
            for j in range(ret.shape[0]):
                if absd[j] >= support:
                    ret[j] = 0.
            return ret

        @numba.jit(nopython=True)
        def compute_dphi1(d):
            """Evaluate d(phi)/dd elementwise on a 1-D array of offsets."""
            absd = numpy.abs(d)
            iabsd = numpy.floor(absd).astype(numpy.int64)
            coefs = dphicoef[numpy.minimum(iabsd, half_order_floor)]
            ret = coefs[:, 0]
            for i in range(1, order - 1):
                ret = ret * absd + coefs[:, i]
            ret = ret * numpy.sign(d)
            for j in range(ret.shape[0]):
                if absd[j] >= support:
                    ret[j] = 0.
            return ret

        @numba.jit(nopython=True)
        def compute_phis(d):
            """Return an (order, 3) array of phi at the ``order`` grid points around offset d."""
            ret = numpy.zeros((order, 3))
            offset = half_order_floor
            for i in range(order):
                ret[i] = compute_phi1(d + i - offset)
            return ret

        @numba.jit(nopython=True)
        def compute_dphis(d):
            """Return an (order, 3) array of d(phi)/dd, laid out like ``compute_phis``."""
            ret = numpy.zeros((order, 3))
            offset = half_order_floor
            for i in range(order):
                ret[i] = compute_dphi1(d + i - offset)
            return ret

        self.gamma = gamma
        self.dgamma = dgamma
        self.compute_phi1 = compute_phi1
        self.compute_dphi1 = compute_dphi1
        self.compute_phis = compute_phis
        self.compute_dphis = compute_dphis

        @numba.jit(nopython=True)
        def elec_short(r2, q1, q2):
            """Short-range (screened) Coulomb pair term; returns (u, -dU/dr / r)."""
            r2inv = 1. / r2
            r = numpy.sqrt(r2)
            rinv = r * r2inv
            prefactor = elec_const * q1 * q2 * rinv
            egamma = 1. - (r * cutinv) * gamma(r * cutinv)
            fgamma = 1. + (r2 * cut2inv) * dgamma(r * cutinv)
            dudr = prefactor * fgamma * r2inv
            u = prefactor * egamma
            return u, dudr

        @numba.jit(nopython=True)
        def elec_full(r2, q1, q2):
            """Bare Coulomb pair term; returns (u, -dU/dr / r)."""
            r2inv = 1. / r2
            r = numpy.sqrt(r2)
            rinv = r * r2inv
            prefactor = elec_const * q1 * q2 * rinv
            dudr = prefactor * r2inv
            u = prefactor
            return u, dudr

        self.pair = elec_short
        self.pair_full = elec_full
        self.get_g_direct()
        self.get_v_direct()

    def get_pair(self):
        """Return the short-range pair function ``(r2, q1, q2) -> (u, dudr)``.

        ``u`` is the MSM-screened Coulomb energy of the pair. ``dudr`` is
        ``-dU/dr / r``, so for separation vector ``d`` the force on the first
        particle is ``d * dudr``.
        """
        return self.pair

    def get_pair_full(self):
        """Return the unscreened Coulomb pair function, with the same signature as ``get_pair``."""
        return self.pair_full

    def particle2grid(self, x, q):
        """Spread charges *q* at positions *x* onto a new level-0 charge grid.

        The grid covers the level-0 "out" range, is indexed ``[z, y, x]``, and
        is appended to ``self.qgrids``.
        """
        order = self.order
        delinv = self.delinv[0]
        grid_lo = self.nlo_out[0]
        grid_hi = self.nhi_out[0]
        ucell_lo = self.ucell_lo

        qgrid = numpy.zeros((grid_hi - grid_lo + 1)[::-1])
        nlower = -((order - 1) // 2)
        nupper = order // 2
        for i in range(len(q)):
            fgrid = (x[i] - ucell_lo) * delinv
            igrid = numpy.floor(fgrid).astype(int)
            # Offset of the lower neighbouring grid point, in (-1, 0].
            dgrid = igrid - fgrid
            ogrid = igrid - grid_lo
            phis = self.compute_phis(dgrid)

            z0 = q[i]
            for iz in range(nlower, nupper + 1):
                mz = ogrid[2] + iz
                y0 = z0 * phis[iz - nlower][2]
                for iy in range(nlower, nupper + 1):
                    my = ogrid[1] + iy
                    x0 = y0 * phis[iy - nlower][1]
                    for ix in range(nlower, nupper + 1):
                        mx = ogrid[0] + ix
                        qgrid[mz, my, mx] += x0 * phis[ix - nlower][0]
        self.qgrids.append(qgrid)

    def get_g_direct(self):
        """Precompute the direct-sum stencil of every level into ``self.g_direct``.

        Entry ``g_direct[n]`` holds, for every grid offset within
        ``±self.ndirect``, the level-n kernel
        ``gamma(rho)/(2**n a) - gamma(rho/2)/(2**(n+1) a)``, where
        ``rho = r / (2**n a)`` and ``a`` is the cutoff.
        """
        self.g_direct = []
        nhi_direct = self.ndirect
        nlo_direct = -nhi_direct
        ksize = nhi_direct - nlo_direct + 1
        a = self.cutoff
        for n in range(self.max_levels):
            g_direct_n = numpy.zeros(ksize[::-1])
            twon = 2 ** n
            delinv = self.delinv[n]

            for iz in range(nlo_direct[2], nhi_direct[2] + 1):
                zdiff = iz / delinv[2]
                for iy in range(nlo_direct[1], nhi_direct[1] + 1):
                    ydiff = iy / delinv[1]
                    for ix in range(nlo_direct[0], nhi_direct[0] + 1):
                        xdiff = ix / delinv[0]
                        d = numpy.asarray([xdiff, ydiff, zdiff])
                        rsq = d.dot(d)
                        c = (iz - nlo_direct[2], iy - nlo_direct[1], ix - nlo_direct[0])
                        rho = numpy.sqrt(rsq) / (twon * a)
                        g_direct_n[c] = self.gamma(rho)/(twon*a) - self.gamma(rho/2.)/(2.*twon*a)
            self.g_direct.append(g_direct_n)

    def get_v_direct(self):
        """Precompute the per-level virial stencils into ``self.v_direct``.

        ``v_direct[n]`` has shape ``(6, kz, ky, kx)``, with components in the
        order xx, yy, zz, xy, xz, yz. Each component is ``dg * d_a * d_b`` for
        grid offset ``d``, where
        ``dg = -(dgamma(rho)/(4**n a**2) - dgamma(rho/2)/(4**(n+1) a**2)) / r``.
        The r == 0 entry stays zero.
        """
        self.v_direct = []
        nhi_direct = self.ndirect
        nlo_direct = -nhi_direct
        ksize = nhi_direct - nlo_direct + 1
        a = self.cutoff
        asq = a ** 2
        for n in range(self.max_levels):
            v_direct_n = numpy.zeros((6,) + tuple(ksize[::-1]))
            twon = 2 ** n
            fourn = twon ** 2
            delinv = self.delinv[n]
            for iz in range(nlo_direct[2], nhi_direct[2] + 1):
                dz = iz / delinv[2]
                for iy in range(nlo_direct[1], nhi_direct[1] + 1):
                    dy = iy / delinv[1]
                    for ix in range(nlo_direct[0], nhi_direct[0] + 1):
                        dx = ix / delinv[0]
                        d = numpy.asarray([dx, dy, dz])
                        r = numpy.sqrt(d.dot(d))
                        if r == 0:
                            continue
                        c = (iz - nlo_direct[2], iy - nlo_direct[1], ix - nlo_direct[0])
                        rho = r / (twon * a)
                        dg = -(self.dgamma(rho)/(fourn*asq) - self.dgamma(rho/2.)/(4*fourn*asq))/r
                        v_direct_n[0][c] = dg*dx*dx
                        v_direct_n[1][c] = dg*dy*dy
                        v_direct_n[2][c] = dg*dz*dz
                        v_direct_n[3][c] = dg*dx*dy
                        v_direct_n[4][c] = dg*dx*dz
                        v_direct_n[5][c] = dg*dy*dz
            self.v_direct.append(v_direct_n)

    def direct(self, n, needvir=False):
        """Convolve the level-n charge grid with the level-n direct stencil.

        The resulting potential grid is appended to ``self.egrids``, and
        ``sum(q * potential)`` over the interior points is added to
        ``self.edir``. When *needvir* is true, the six virial components
        (xx, yy, zz, xy, xz, yz) are accumulated and appended to
        ``self.vgrids``, which ``compute`` initialises.
        """
        nhi_direct = self.ndirect
        nlo_direct = -nhi_direct
        qgrid = self.qgrids[n]
        egrid = numpy.zeros(qgrid.shape)
        nmsm = self.nmsm[n]
        nlo_out = self.nlo_out[n]
        g_direct = self.g_direct[n]
        v_direct = self.v_direct[n]
        virial = numpy.zeros(6)

        for iz in range(nmsm[2]):
            gz = iz - nlo_out[2]
            for iy in range(nmsm[1]):
                gy = iy - nlo_out[1]
                for ix in range(nmsm[0]):
                    gx = ix - nlo_out[0]
                    # Stencil window centred on (gx, gy, gz), in [z, y, x] order.
                    window = (
                        slice(gz + nlo_direct[2], gz + nhi_direct[2] + 1),
                        slice(gy + nlo_direct[1], gy + nhi_direct[1] + 1),
                        slice(gx + nlo_direct[0], gx + nhi_direct[0] + 1),
                    )
                    qwin = qgrid[window]
                    esum = (g_direct * qwin).sum()
                    egrid[gz, gy, gx] = esum
                    qcenter = qgrid[gz, gy, gx]
                    self.edir += esum * qcenter
                    if needvir:
                        # Sum over the stencil axes only, so the six
                        # components stay separate.
                        vsum = (v_direct * qwin).sum(axis=(1, 2, 3))
                        virial += vsum * qcenter

        self.egrids.append(egrid)
        if needvir:
            self.vgrids.append(virial)

    def reverse_pbc_grid(self, grid, n, inplace=False):
        """Fold level-n ghost points back onto the periodic interior.

        Axes are processed in the order z, then y, then x. The folded grid is
        returned. When *inplace* is false, a copy is modified and *grid* is
        left untouched.
        """
        nout = self.nhi_out[n] - self.nlo_out[n] + 1
        nin = self.nhi_in[n] - self.nlo_in[n] + 1

        hlo = self.nlo_in[n] - self.nlo_out[n]
        hhi = self.nhi_out[n] - self.nhi_in[n]

        grid_out_z = grid if inplace else numpy.copy(grid)
        # Views restricted to the interior along the already-folded axes.
        grid_out_y = grid_out_z[hlo[2]:nout[2] - hhi[2]]
        grid_out_x = grid_out_y[:, hlo[1]:nout[1]-hhi[1]]

        # Ascending/descending order lets ghost layers wider than the interior
        # fold repeatedly until they reach it.
        for iz in range(hlo[2]):
            grid_out_z[iz + nin[2], :, :] += grid_out_z[iz, :, :]
        for iz in range(nout[2] - 1, nout[2] - hhi[2] - 1, -1):
            grid_out_z[iz - nin[2], :, :] += grid_out_z[iz, :, :]

        debug(hlo, nin)
        for iy in range(hlo[1]):
            grid_out_y[:, iy + nin[1], :] += grid_out_y[:, iy, :]
        for iy in range(nout[1] - 1, nout[1] - hhi[1] - 1, -1):
            grid_out_y[:, iy - nin[1], :] += grid_out_y[:, iy, :]

        for ix in range(hlo[0]):
            grid_out_x[:, :, ix + nin[0]] += grid_out_x[:, :, ix]
        for ix in range(nout[0] - 1, nout[0] - hhi[0] - 1, -1):
            grid_out_x[:, :, ix - nin[0]] += grid_out_x[:, :, ix]

        return grid_out_z

    def forward_pbc_grid(self, grid, n, inplace=False):
        """Fill level-n ghost points with periodic copies of the interior.

        Axes are processed in the order x, then y, then z. The filled grid is
        returned. When *inplace* is false, a copy is modified.
        """
        nout = self.nhi_out[n] - self.nlo_out[n] + 1
        nin = self.nhi_in[n] - self.nlo_in[n] + 1

        hlo = self.nlo_in[n] - self.nlo_out[n]
        hhi = self.nhi_out[n] - self.nhi_in[n]

        grid_out_z = grid if inplace else numpy.copy(grid)
        grid_out_y = grid_out_z[hlo[2]:nout[2]-hhi[2]]
        grid_out_x = grid_out_y[:, hlo[1]:nout[1]-hhi[1]]
        for ix in range(hlo[0] - 1, -1, -1):
            grid_out_x[:, :, ix] = grid_out_x[:, :, ix + nin[0]]
        for ix in range(nout[0] - hhi[0], nout[0]):
            grid_out_x[:, :, ix] = grid_out_x[:, :, ix - nin[0]]

        for iy in range(hlo[1] - 1, -1, -1):
            grid_out_y[:, iy, :] = grid_out_y[:, iy + nin[1], :]
        for iy in range(nout[1] - hhi[1], nout[1]):
            grid_out_y[:, iy, :] = grid_out_y[:, iy - nin[1], :]

        for iz in range(hlo[2] - 1, -1, -1):
            grid_out_z[iz, :, :] = grid_out_z[iz + nin[2], :, :]
        for iz in range(nout[2] - hhi[2], nout[2]):
            grid_out_z[iz, :, :] = grid_out_z[iz - nin[2], :, :]

        return grid_out_z

    def wrap_pbc_grid(self, grid, n, inplace=False):
        """Fold ghosts into the interior, then refill ghosts from it; returns the grid."""
        rev = self.reverse_pbc_grid(grid, n, inplace)
        fwd = self.forward_pbc_grid(rev, n, inplace)
        return fwd

    def _transfer_stencil(self, n):
        """Return ``(k2nu, phi1d)`` for moving between levels n and n+1.

        ``k2nu`` lists the level-n offsets nu in ``-p .. p`` (p = order - 1)
        that are odd or zero. ``phi1d[k]`` is phi evaluated at ``k2nu[k]``
        level-n steps, expressed in level-(n+1) grid units, with one column
        per dimension (x, y, z).
        """
        order = self.order
        p = order - 1
        k2nu = numpy.zeros(order + 1, dtype=int)
        phi1d = numpy.zeros((order + 1, 3))
        delinvn = self.delinv[n]
        delinvnp1 = self.delinv[n + 1]
        k = 0
        for nu in range(-p, p + 1):
            if nu % 2 == 0 and nu != 0:
                continue
            phi1d[k] = self.compute_phi1(nu * delinvnp1 / delinvn)
            k2nu[k] = nu
            k += 1
        return k2nu, phi1d

    def restriction(self, n):
        """Restrict the level-n charge grid to a new level-(n+1) grid.

        Each interior point of level n+1 gathers charge from the level-n points
        around its level-n image. The weights are the separable product of the
        ``phi1d`` values in x, y and z. The new grid is appended to
        ``self.qgrids``.
        """
        qgrid = self.qgrids[n]
        k2nu, phi1d = self._transfer_stencil(n)
        debug(phi1d[:, 0])
        nlo_outn = self.nlo_out[n]
        nlo_innp1 = self.nlo_in[n + 1]
        nhi_innp1 = self.nhi_in[n + 1]
        nlo_outnp1 = self.nlo_out[n + 1]
        qgrid_out = numpy.zeros((self.nhi_out[n + 1] - nlo_outnp1 + 1)[::-1])
        debug(qgrid_out.shape)
        # Level-n points per level-(n+1) step in each dimension.
        stride = numpy.floor(self.delinv[n] / self.delinv[n + 1] + 1e-8).astype(int)
        nstencil = len(k2nu)

        for iz in range(nlo_innp1[2], nhi_innp1[2] + 1):
            for iy in range(nlo_innp1[1], nhi_innp1[1] + 1):
                for ix in range(nlo_innp1[0], nhi_innp1[0] + 1):
                    c = numpy.asarray([ix, iy, iz], dtype=int) * stride
                    q2sum = 0
                    for dz in range(nstencil):
                        kk = c[2] + k2nu[dz] - nlo_outn[2]
                        phiz = phi1d[dz][2]
                        for dy in range(nstencil):
                            jj = c[1] + k2nu[dy] - nlo_outn[1]
                            phizy = phi1d[dy][1] * phiz
                            for dx in range(nstencil):
                                ii = c[0] + k2nu[dx] - nlo_outn[0]
                                q2sum += qgrid[kk, jj, ii] * phi1d[dx][0] * phizy
                    qgrid_out[iz - nlo_outnp1[2], iy - nlo_outnp1[1], ix - nlo_outnp1[0]] += q2sum
        self.qgrids.append(qgrid_out)

    def prolongation(self, n):
        """Interpolate the level-(n+1) potential grid onto level n, in place.

        This mirrors ``restriction``: every interior point of level n+1
        scatters its potential onto ``self.egrids[n]`` with the same separable
        ``phi1d`` weights.
        """
        egrid_in = self.egrids[n + 1]
        egrid_out = self.egrids[n]
        debug(egrid_out.shape)
        k2nu, phi1d = self._transfer_stencil(n)
        debug(phi1d[:, 0])
        nlo_outn = self.nlo_out[n]
        nlo_innp1 = self.nlo_in[n + 1]
        nhi_innp1 = self.nhi_in[n + 1]
        nlo_outnp1 = self.nlo_out[n + 1]
        stride = numpy.floor(self.delinv[n] / self.delinv[n + 1] + 1e-8).astype(int)
        nstencil = len(k2nu)

        for iz in range(nlo_innp1[2], nhi_innp1[2] + 1):
            for iy in range(nlo_innp1[1], nhi_innp1[1] + 1):
                for ix in range(nlo_innp1[0], nhi_innp1[0] + 1):
                    c = numpy.asarray([ix, iy, iz], dtype=int) * stride
                    etmp = egrid_in[iz - nlo_outnp1[2], iy - nlo_outnp1[1], ix - nlo_outnp1[0]]
                    for dz in range(nstencil):
                        kk = c[2] + k2nu[dz] - nlo_outn[2]
                        phiz = phi1d[dz][2]
                        for dy in range(nstencil):
                            jj = c[1] + k2nu[dy] - nlo_outn[1]
                            phizy = phi1d[dy][1] * phiz
                            for dx in range(nstencil):
                                ii = c[0] + k2nu[dx] - nlo_outn[0]
                                phi3d = phizy * phi1d[dx][0]
                                egrid_out[kk, jj, ii] += etmp * phi3d

    def grid2particle(self, x, q, f):
        """Interpolate the level-0 potential gradient to forces and add them to *f*.

        For each particle i, ``elec_const * q[i]`` times the interpolated field
        is added to ``f[i]``. *f* is modified in place and returned.
        """
        egrid = self.egrids[0]
        order = self.order
        delinv = self.delinv[0]
        grid_lo = self.nlo_out[0]
        ucell_lo = self.ucell_lo
        nlower = -((order - 1) // 2)
        nupper = order // 2
        for i in range(len(q)):
            fgrid = (x[i] - ucell_lo) * delinv
            igrid = numpy.floor(fgrid).astype(int)
            dgrid = igrid - fgrid
            ogrid = igrid - grid_lo
            phis = self.compute_phis(dgrid)
            dphis = self.compute_dphis(dgrid)
            e = numpy.zeros(3)
            for iz in range(nlower, nupper + 1):
                mz = ogrid[2] + iz
                phiz = phis[iz - nlower][2]
                dphiz = dphis[iz - nlower][2]
                for iy in range(nlower, nupper + 1):
                    my = ogrid[1] + iy
                    phiy = phis[iy - nlower][1]
                    dphiy = dphis[iy - nlower][1]
                    for ix in range(nlower, nupper + 1):
                        mx = ogrid[0] + ix
                        phix = phis[ix - nlower][0]
                        dphix = dphis[ix - nlower][0]
                        etmp = egrid[mz, my, mx]
                        e[0] += dphix*phiy*phiz*etmp
                        e[1] += phix*dphiy*phiz*etmp
                        e[2] += phix*phiy*dphiz*etmp
            # Chain rule: d(dgrid)/dx is -delinv, which cancels the minus
            # sign of F = -q grad(V).
            e *= delinv
            qfactor = self.elec_const*q[i]
            f[i] += e * qfactor
        return f

    def compute(self, f, x, q):
        """Run the long-range MSM pass.

        Long-range forces are added to *f* in place. Returns the long-range
        energy, i.e. ``0.5 * elec_const * (edir - e_self)``.
        """
        self.edir = 0
        self.qgrids = []
        self.egrids = []
        self.vgrids = []

        self.particle2grid(x, q)
        self.tqgrid = self.qgrids[0].copy()  # level-0 charges before wrapping
        # Upward pass: direct sum on every level, restricting as we go.
        for i in range(self.max_levels):
            self.wrap_pbc_grid(self.qgrids[i], i, inplace=True)
            self.direct(i)
            debug(self.edir)
            if i != self.max_levels - 1:
                self.restriction(i)
        self.tegrid = self.egrids[0].copy()  # level-0 potential before prolongation
        # Downward pass: accumulate coarse potentials into the finer levels.
        for i in range(self.max_levels - 2, -1, -1):
            self.wrap_pbc_grid(self.egrids[i + 1], i + 1, inplace=True)
            self.prolongation(i)
        self.wrap_pbc_grid(self.egrids[0], 0, inplace=True)

        self.grid2particle(x, q, f)
        qsqsum = (q * q).sum()
        # Interaction of each charge with its own smoothed potential.
        e_self = qsqsum*self.gamma(0.0)/self.cutoff
        energy = (self.edir - e_self) * 0.5 * self.elec_const
        debug(qsqsum, self.edir, e_self)
        debug(energy)
        return energy
if __name__ == '__main__':
    # Regression driver: runs the long-range MSM pass on dumped coordinates
    # and charges, then prints the mean squared deviation from the reference
    # forces.
    import types

    def load_f64_array(path, shape):
        """Read a raw native-endian float64 dump and reshape it."""
        return numpy.fromfile(path, dtype=numpy.float64).reshape(shape)

    x = load_f64_array("../../qeq/x.bin", (1680, 3))
    q = load_f64_array("../../qeq/q.bin", (1680))
    f = load_f64_array("../../qeq/f_before_ff_o4.bin", (1680, 3))
    fstd = load_f64_array("../../qeq/f_after_ff_o4.bin", (1680, 3))

    # MSM reads its settings from a config object; mirror the old positional
    # arguments (elec_const=14.399645, cell (100, 100, 25), levels (3, 3, 2),
    # cutoff 12, skin 1, order 4). scale14 is stored but not used by the
    # long-range pass.
    basis = numpy.asarray((100., 100., 25.))
    coulprops = types.SimpleNamespace(levels=numpy.asarray((3, 3, 2)), order=4)
    nonbprops = types.SimpleNamespace(coulprops=coulprops, cutoff=12, skin=1, scale14=1.0)
    config = types.SimpleNamespace(basis=basis, nonbprops=nonbprops)
    M = MSM(config, 14.399645)

    # These positions were previously used with ucell_lo = (0, 0, 0). MSM now
    # centres the cell on the origin (ucell_lo = -basis/2), so shifting by
    # basis/2 reproduces the same fractional grid coordinates.
    x_centered = x - 0.5 * basis
    ffinal = f.copy()
    energy = M.compute(ffinal, x_centered, q)
    print(energy)

    pair = M.get_pair()

    @numba.jit(nopython=True)
    def eval_pair(x, q, nlocal):
        """Sum short-range pair forces on the first *nlocal* atoms (cutoff 12)."""
        f = numpy.zeros((nlocal, 3))
        for i in range(nlocal):
            for j in range(len(x)):
                if i == j:
                    continue
                d = x[i] - x[j]
                r2 = d.dot(d)
                if r2 < 144:
                    epair, fpair = pair(r2, q[i], q[j])
                    f[i] += d * fpair
        return f

    def wrap_pbc_atoms(xlocal, qlocal, boxlo, boxhi, dist):
        """Append periodic images of atoms lying within *dist* of a box face."""
        boxsize = boxhi - boxlo
        x = xlocal
        q = qlocal
        for dim in range(3):
            atom_lo = x[:, dim] - dist < boxlo[dim]
            atom_hi = x[:, dim] + dist > boxhi[dim]
            offset = numpy.zeros(3)
            offset[dim] = boxsize[dim]
            x = numpy.concatenate((x, x[atom_lo, :] + offset, x[atom_hi, :] - offset))
            q = numpy.concatenate((q, q[atom_lo], q[atom_hi]))
        return x, q

    print(((fstd - ffinal) ** 2).mean())
